/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/* !
 * \file kernel_micro_vec_binary_scalar_impl.h
 * \brief
 */
#ifndef ASCENDC_MODULE_MICRO_VEC_BINARY_SCALAR_IMPL_H
#define ASCENDC_MODULE_MICRO_VEC_BINARY_SCALAR_IMPL_H

namespace AscendC {
namespace MicroAPI {
/*
 * Adds a scalar to every active lane of a complex register held in RegTraitNumTwo layout.
 *
 * The scalar is first converted to the register's complex type (RegT::ActualT). Its .real and
 * .imag parts are broadcast into two component-typed temporaries. Those are then added to
 * srcReg.reg[0] and srcReg.reg[1], with the results written to dstReg.reg[0] and dstReg.reg[1].
 *
 * dstReg  destination register, RegTraitNumTwo layout
 * srcReg  source register, RegTraitNumTwo layout
 * scalar  value converted to RegT::ActualT before being broadcast
 * mask    predicate used for both broadcasts and both additions
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void AddsComplexTraitTwoImpl(RegT &dstReg, RegT &srcReg, const ScalarT& scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    using PartT = typename ActualT::EleType;
    static_assert(CheckRegTrait<RegT, RegTraitNumTwo>(), "RegT should be RegTraitNumTwo");
    constexpr auto mergeMode = GetMaskMergeMode<mode>();

    ActualT converted(scalar);
    RegTensor<PartT> realBroadcast;
    RegTensor<PartT> imagBroadcast;
    Duplicate(realBroadcast, converted.real, mask);
    Duplicate(imagBroadcast, converted.imag, mask);

    // View each half of the two-register layout as a plain register of the component type.
    RegTensor<PartT> &srcReal = (RegTensor<PartT> &)srcReg.reg[0];
    RegTensor<PartT> &srcImag = (RegTensor<PartT> &)srcReg.reg[1];
    RegTensor<PartT> &dstReal = (RegTensor<PartT> &)dstReg.reg[0];
    RegTensor<PartT> &dstImag = (RegTensor<PartT> &)dstReg.reg[1];
    vadd(dstReal, srcReal, realBroadcast, mask, mergeMode);
    vadd(dstImag, srcImag, imagBroadcast, mask, mergeMode);
}

/*
 * Adds a scalar to every active lane of a complex register held in RegTraitNumOne layout.
 *
 * Steps:
 *  1. Convert srcReg into RegTraitNumTwo layout (TraitOneToTaitTwoTmpl).
 *  2. Perform the addition with AddsComplexTraitTwoImpl, forwarding T, ScalarT and the merge mode.
 *  3. Convert the result back into dstReg (TraitTwoToTaitOneTmpl).
 *
 * The mask is repacked with MaskPack before it is applied to the two-register layout.
 * NOTE(review): presumably this is because lane granularity differs between the two layouts;
 * confirm.
 *
 * dstReg  destination register, RegTraitNumOne layout
 * srcReg  source register, RegTraitNumOne layout
 * scalar  value added to every active lane (converted to RegT::ActualT by the trait-two helper)
 * mask    predicate in RegTraitNumOne granularity
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void AddsComplexTraitOneImpl(RegT &dstReg, RegT &srcReg, const ScalarT& scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    using EleT = typename ActualT::EleType;
    using TraitOneRegT = RegTensor<ActualT, RegTraitNumOne>;
    using TraitTwoRegT = RegTensor<ActualT, RegTraitNumTwo>;
    static_assert(CheckRegTrait<RegT, RegTraitNumOne>(), "RegT should be RegTraitNumOne");

    MaskReg maskTrait2;
    MaskPack(maskTrait2, mask);
    TraitTwoRegT traitTwoSrcReg;
    TraitTwoRegT traitTwoDstReg;
    TraitOneToTaitTwoTmpl<TraitTwoRegT, TraitOneRegT, EleT>(traitTwoSrcReg, srcReg);
    // Forward the caller's merge mode instead of silently falling back to the helper's default.
    AddsComplexTraitTwoImpl<T, ScalarT, mode>(traitTwoDstReg, traitTwoSrcReg, scalar, maskTrait2);
    TraitTwoToTaitOneTmpl<TraitOneRegT, TraitTwoRegT, EleT>(dstReg, traitTwoDstReg);
}

/*
 * Element-wise dstReg = srcReg + scalar on the lanes selected by mask.
 *
 * Implementation chosen at compile time:
 *  - complex32: RegTraitNumTwo goes to AddsComplexTraitTwoImpl; RegTraitNumOne goes to
 *    AddsComplexTraitOneImpl.
 *  - Any other element type narrower or wider than 8 bytes: the vadds instruction.
 *  - 8-byte non-complex types: the scalar is broadcast with Duplicate, then combined with
 *    the register-register Add.
 *  - complex64, RegTraitNumTwo: AddsComplexTraitTwoImpl directly.
 *  - complex64, RegTraitNumOne: converted with B64TraitOneToTaitTwo / B64TraitTwoToTaitOne
 *    around AddsComplexTraitTwoImpl, using a mask repacked by MaskPack.
 *
 * Only MaskMergeMode::ZEROING is accepted.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void AddsImpl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(
        SupportType<ActualT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, half, float,
            bfloat16_t, uint64_t, int64_t, complex32, complex64>(),
        "current data type is not supported on current device!");
    static_assert(
        SupportType<ScalarT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, half, float,
            bfloat16_t, uint64_t, int64_t, complex32, complex64>(),
        "current scalar data type is not supported on current device!");
    static_assert(Std::is_convertible<ScalarT, ActualT>(), "scalar data type could be converted to RegTensor data type");
    static_assert(SupportEnum<mode, MaskMergeMode::ZEROING>(),
        "current Adds api only supported Mode ZEROING on current device!");

    constexpr bool isComplex32 = SupportType<ActualT, complex32>();
    constexpr bool isEightBytes = (sizeof(ActualT) == 8);

    if constexpr (isComplex32) {
        if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
            AddsComplexTraitTwoImpl(dstReg, srcReg, scalar, mask);
        } else {
            AddsComplexTraitOneImpl(dstReg, srcReg, scalar, mask);
        }
    } else if constexpr (!isEightBytes) {
        constexpr auto modeValue = GetMaskMergeMode<mode>();
        vadds(dstReg, srcReg, scalar, mask, modeValue);
    } else if constexpr (!SupportType<ActualT, complex64>()) {
        // 8-byte non-complex types do not use vadds here: broadcast, then register-register Add.
        RegT scalarReg;
        Duplicate(scalarReg, scalar, mask);
        Add(dstReg, srcReg, scalarReg, mask);
    } else if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
        AddsComplexTraitTwoImpl(dstReg, srcReg, scalar, mask);
    } else {
        MaskReg maskTrait2;
        MaskPack(maskTrait2, mask);
        RegTensor<ActualT, RegTraitNumTwo> traitTwoSrcReg;
        RegTensor<ActualT, RegTraitNumTwo> traitTwoDstReg;
        B64TraitOneToTaitTwo(traitTwoSrcReg, srcReg);
        AddsComplexTraitTwoImpl(traitTwoDstReg, traitTwoSrcReg, scalar, maskTrait2);
        B64TraitTwoToTaitOne(dstReg, traitTwoDstReg);
    }
}

/*
 * Multiplies every active lane of a complex register in RegTraitNumTwo layout by a broadcast
 * complex scalar. With src = a + b*i (a in reg[0], b in reg[1]) and scalar = c + d*i:
 *     dst.reg[0] = a*c - b*d
 *     dst.reg[1] = b*c + a*d
 *
 * The scalar is converted to RegT::ActualT before its .real/.imag parts are read, which matches
 * AddsComplexTraitTwoImpl. This lets a non-complex ScalarT that MulsImpl accepts (e.g. float for
 * a complex64 register, when Std::is_convertible allows it) compile. Such a scalar is presumably
 * promoted with a zero imaginary part; confirm against the complex type's constructor.
 *
 * All four partial products are formed before dstReg is written, so dstReg may alias srcReg.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void MulsKernel(RegT &dstReg, RegT &srcReg, const ScalarT& scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    using EleT = typename ActualT::EleType;
    static_assert(CheckRegTrait<RegT, RegTraitNumTwo>(), "RegT should be RegTraitNumTwo");
    constexpr auto modeValue = GetMaskMergeMode<mode>();

    // Convert to the register's complex type: ScalarT itself may not have .real/.imag members.
    ActualT scalarAux(scalar);
    RegTensor<EleT> src1Real;
    RegTensor<EleT> src1Imag;
    Duplicate(src1Real, scalarAux.real, mask);
    Duplicate(src1Imag, scalarAux.imag, mask);

    RegTensor<EleT> &src0Real = (RegTensor<EleT> &)srcReg.reg[0];
    RegTensor<EleT> &src0Imag = (RegTensor<EleT> &)srcReg.reg[1];
    RegTensor<EleT> &dstReal = (RegTensor<EleT> &)dstReg.reg[0];
    RegTensor<EleT> &dstImag = (RegTensor<EleT> &)dstReg.reg[1];
    RegTensor<EleT> e; // a*c
    RegTensor<EleT> f; // b*d
    RegTensor<EleT> g; // b*c
    RegTensor<EleT> h; // a*d
    vmul(e, src0Real, src1Real, mask, modeValue);
    vmul(f, src0Imag, src1Imag, mask, modeValue);
    vmul(g, src0Imag, src1Real, mask, modeValue);
    vmul(h, src0Real, src1Imag, mask, modeValue);
    vsub(dstReal, e, f, mask, modeValue);
    vadd(dstImag, g, h, mask, modeValue);
}

/*
 * Element-wise dstReg = srcReg * scalar on the lanes selected by mask.
 *
 * Implementation chosen at compile time:
 *  - Complex types (complex32 for sub-8-byte elements, complex64 for 8-byte elements) go to
 *    MulsKernel.
 *    - RegTraitNumTwo registers are passed straight through.
 *    - RegTraitNumOne registers are converted to RegTraitNumTwo and back around the kernel, with
 *      the B32 or B64 trait converters matching the element size and a mask repacked by MaskPack.
 *  - Non-complex 8-byte types: the scalar is broadcast with Duplicate, then combined with Mul.
 *  - All other types: the vmuls instruction.
 *
 * Only MaskMergeMode::ZEROING is accepted.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void MulsImpl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    using TraitTwoRegT = RegTensor<ActualT, RegTraitNumTwo>;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(SupportType<ActualT, uint16_t, int16_t, uint32_t, int32_t, half, float,
        uint64_t, int64_t, complex32, complex64>(),
        "current data type is not supported on current device!");
    static_assert(SupportType<ScalarT, uint16_t, int16_t, uint32_t, int32_t, half, float,
        uint64_t, int64_t, complex32, complex64>(),
        "current scalar data type is not supported on current device!");
    static_assert(Std::is_convertible<ScalarT, ActualT>(), "scalar data type could be converted to RegTensor data type");
    static_assert(SupportEnum<mode, MaskMergeMode::ZEROING>(),
        "current Muls api only supported Mode ZEROING on current device!");
    constexpr auto modeValue = GetMaskMergeMode<mode>();

    constexpr bool isB64 = (sizeof(ActualT) == 8);
    constexpr bool isComplex =
        isB64 ? SupportType<ActualT, complex64>() : SupportType<ActualT, complex32>();

    if constexpr (!isComplex) {
        if constexpr (isB64) {
            RegT scalarReg;
            Duplicate(scalarReg, scalar, mask);
            Mul(dstReg, srcReg, scalarReg, mask);
        } else {
            vmuls(dstReg, srcReg, scalar, mask, modeValue);
        }
    } else if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
        MulsKernel<T, ScalarT, mode, TraitTwoRegT>(dstReg, srcReg, scalar, mask);
    } else {
        MaskReg maskTrait2;
        MaskPack(maskTrait2, mask);
        TraitTwoRegT traitTwoSrcReg;
        TraitTwoRegT traitTwoDstReg;
        if constexpr (isB64) {
            B64TraitOneToTaitTwo(traitTwoSrcReg, srcReg);
        } else {
            B32TraitOneToTaitTwo(traitTwoSrcReg, srcReg);
        }
        MulsKernel<T, ScalarT, mode, TraitTwoRegT>(traitTwoDstReg, traitTwoSrcReg, scalar, maskTrait2);
        if constexpr (isB64) {
            B64TraitTwoToTaitOne(dstReg, traitTwoDstReg);
        } else {
            B32TraitTwoToTaitOne(dstReg, traitTwoDstReg);
        }
    }
}

/*
 * Element-wise maximum against a scalar on the lanes selected by mask.
 *
 * 8-byte element types broadcast the scalar with Duplicate and use the register-register Max.
 * All other supported types use the vmaxs instruction directly.
 * Only MaskMergeMode::ZEROING is accepted.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void MaxsImpl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(
        SupportType<ActualT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, half, float,
            bfloat16_t, uint64_t, int64_t>(),
        "current data type is not supported on current device!");
    static_assert(
        SupportType<ScalarT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, half, float,
            bfloat16_t, uint64_t, int64_t>(),
        "current scalar data type is not supported on current device!");
    static_assert(Std::is_convertible<ScalarT, ActualT>(), "scalar data type could be converted to RegTensor data type");
    static_assert(SupportEnum<mode, MaskMergeMode::ZEROING>(),
        "current Maxs api only supported Mode ZEROING on current device!");

    if constexpr (sizeof(ActualT) == 8) {
        RegT broadcastReg;
        Duplicate(broadcastReg, scalar, mask);
        Max(dstReg, srcReg, broadcastReg, mask);
    } else {
        constexpr auto mergeMode = GetMaskMergeMode<mode>();
        vmaxs(dstReg, srcReg, scalar, mask, mergeMode);
    }
}

/*
 * Element-wise minimum against a scalar on the lanes selected by mask.
 *
 * 8-byte element types broadcast the scalar with Duplicate and use the register-register Min.
 * All other supported types use the vmins instruction directly.
 * Only MaskMergeMode::ZEROING is accepted.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void MinsImpl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(
        SupportType<ActualT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, half, float, bfloat16_t, uint64_t, int64_t>(),
        "current data type is not supported on current device!");
    static_assert(
        SupportType<ScalarT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, half, float, bfloat16_t, uint64_t, int64_t>(),
        "current scalar data type is not supported on current device!");
    static_assert(Std::is_convertible<ScalarT, ActualT>(), "scalar data type could be converted to RegTensor data type");
    static_assert(SupportEnum<mode, MaskMergeMode::ZEROING>(),
        "current Mins api only supported Mode ZEROING on current device!");

    if constexpr (sizeof(ActualT) == 8) {
        RegT broadcastReg;
        Duplicate(broadcastReg, scalar, mask);
        Min(dstReg, srcReg, broadcastReg, mask);
    } else {
        constexpr auto mergeMode = GetMaskMergeMode<mode>();
        vmins(dstReg, srcReg, scalar, mask, mergeMode);
    }
}

/*
 * Element-wise left shift of srcReg by an int16_t scalar count on the lanes selected by mask.
 *
 * Element types narrower than 8 bytes use the vshls instruction directly.
 *
 * 64-bit element types are emulated on 32-bit halves by ShiftLeftsB64Impl:
 *  - A RegTraitNumTwo register is shifted into a temporary that is then copied to dstReg. The
 *    helper writes dstReg.reg[0] before it reads srcReg.reg[0] again, so the temporary keeps
 *    in-place calls (dstReg aliasing srcReg) correct.
 *  - A RegTraitNumOne register is converted to RegTraitNumTwo and back around the helper, using
 *    a mask repacked with MaskPack.
 *
 * Only MaskMergeMode::ZEROING is accepted.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void ShiftLeftsImpl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(SupportType<ActualT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t>(),
        "current data type is not supported on current device!");
    static_assert(SupportType<ScalarT, int16_t>(), "current scalar data type is not supported on current device!");
    static_assert(SupportEnum<mode, MaskMergeMode::ZEROING>(),
        "current ShiftLefts api only supported Mode ZEROING on current device!");
    if constexpr (sizeof(ActualT) != 8) {
        // Only the native vshls path consumes the merge-mode value, so it is declared here once.
        constexpr auto modeValue = GetMaskMergeMode<mode>();
        vshls(dstReg, srcReg, scalar, mask, modeValue);
    } else {
        if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
            RegT dstTemp;
            ShiftLeftsB64Impl(dstTemp, srcReg, scalar, mask);
            dstReg = dstTemp;
        } else if constexpr (CheckRegTrait<RegT, RegTraitNumOne>()) {
            MaskReg maskTrait2;
            MaskPack(maskTrait2, mask);
            RegTensor<ActualT, RegTraitNumTwo> traitTwoSrcReg0;
            RegTensor<ActualT, RegTraitNumTwo> traitTwoDstReg;
            B64TraitOneToTaitTwo(traitTwoSrcReg0, srcReg);
            ShiftLeftsB64Impl(traitTwoDstReg, traitTwoSrcReg0, scalar, maskTrait2);
            B64TraitTwoToTaitOne(dstReg, traitTwoDstReg);
        }
    }
}

/*
 * Left-shifts each 64-bit lane of a RegTraitNumTwo register by `scalar`. It is built from 32-bit
 * shifts on the two halves reg[0] and reg[1], combined with Or. Only uint64_t and int64_t are
 * handled; for any other ActualT the function body is empty.
 *
 * NOTE(review): the formulas read as if reg[0] holds the low 32 bits of each lane and reg[1] the
 * high 32 bits. Confirm against the RegTraitNumTwo layout.
 *
 * NOTE(review): the counts `B32Width + scalar` and `B32Width - scalar` can be >= 32 or negative.
 * The result is only correct if vshls/vshrs yield 0 for counts >= 32 and treat a negative count
 * as a shift in the opposite direction. Confirm against the instruction documentation.
 *
 * Aliasing: dstReg.reg[0] is written before srcReg.reg[0] is read again, so dstReg must not
 * alias srcReg. ShiftLeftsImpl passes a separate temporary as dstReg.
 *
 * NOTE(review): ShiftLeftsImpl calls this helper unqualified before it is declared. That
 * presumably resolves through argument-dependent lookup at instantiation; consider moving this
 * definition above its caller.
 */
template <typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void ShiftLeftsB64Impl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    constexpr auto modeValue = GetMaskMergeMode<mode>();
    // Bit width of one 32-bit half.
    int16_t B32Width = 32;
    if constexpr (std::is_same_v<ActualT, uint64_t>) {
        RegTensor<uint32_t> tmpReg0;
        RegTensor<uint32_t> tmpReg1;
        // dst.reg[0] = (src.reg[0] << scalar) | (src.reg[1] << (32 + scalar))
        vshls(tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], scalar, mask, modeValue);
        vshls(tmpReg1, (RegTensor<uint32_t> &)srcReg.reg[1], B32Width + scalar, mask, modeValue);
        Or((RegTensor<uint32_t> &)dstReg.reg[0], tmpReg0, tmpReg1, mask);
        // dst.reg[1] = (src.reg[0] >> (32 - scalar)) | (src.reg[1] << scalar)
        // The first term carries the bits shifted out of reg[0] into reg[1].
        vshrs(tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], B32Width - scalar, mask, modeValue);
        vshls(tmpReg1, (RegTensor<uint32_t> &)srcReg.reg[1], scalar, mask, modeValue);
        Or((RegTensor<uint32_t> &)dstReg.reg[1], tmpReg0, tmpReg1, mask);
    } else if constexpr (std::is_same_v<ActualT, int64_t>) {
        // Same scheme as uint64_t. The shifts that touch src.reg[0] and the reg[0]-result shifts
        // run on uint32_t views; only the final `src.reg[1] << scalar` runs on int32_t.
        RegTensor<int32_t> tmpReg0;
        RegTensor<int32_t> tmpReg1;
        vshls((RegTensor<uint32_t> &)tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], scalar, mask, modeValue);
        vshls((RegTensor<uint32_t> &)tmpReg1, (RegTensor<uint32_t> &)srcReg.reg[1], B32Width + scalar, mask, modeValue);
        Or((RegTensor<int32_t> &)dstReg.reg[0], tmpReg0, tmpReg1, mask);
        vshrs((RegTensor<uint32_t> &)tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], B32Width - scalar, mask, modeValue);
        vshls(tmpReg1, (RegTensor<int32_t> &)srcReg.reg[1], scalar, mask, modeValue);
        Or((RegTensor<int32_t> &)dstReg.reg[1], tmpReg0, tmpReg1, mask);
    }
}

/*
 * Element-wise right shift of srcReg by an int16_t scalar count on the lanes selected by mask.
 *
 * 64-bit element types are emulated on 32-bit halves by ShiftRightsB64Impl:
 *  - A RegTraitNumTwo register is shifted into a temporary that is then copied to dstReg.
 *  - A RegTraitNumOne register is converted to RegTraitNumTwo and back around the helper, using
 *    a mask repacked with MaskPack.
 *
 * Element types narrower than 8 bytes use the vshrs instruction directly.
 * Only MaskMergeMode::ZEROING is accepted.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void ShiftRightsImpl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    using TraitTwoRegT = RegTensor<ActualT, RegTraitNumTwo>;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(SupportType<ActualT, uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t>(),
        "current data type is not supported on current device!");
    static_assert(SupportType<ScalarT, int16_t>(), "current scalar data type is not supported on current device!");
    static_assert(SupportEnum<mode, MaskMergeMode::ZEROING>(),
        "current ShiftRights api only supported Mode ZEROING on current device!");

    if constexpr (sizeof(ActualT) == 8) {
        if constexpr (CheckRegTrait<RegT, RegTraitNumTwo>()) {
            // Shift into a temporary: the helper rereads srcReg.reg[0] after writing dstReg.reg[0].
            RegT shifted;
            ShiftRightsB64Impl(shifted, srcReg, scalar, mask);
            dstReg = shifted;
        } else if constexpr (CheckRegTrait<RegT, RegTraitNumOne>()) {
            MaskReg packedMask;
            MaskPack(packedMask, mask);
            TraitTwoRegT splitSrc;
            TraitTwoRegT splitDst;
            B64TraitOneToTaitTwo(splitSrc, srcReg);
            ShiftRightsB64Impl(splitDst, splitSrc, scalar, packedMask);
            B64TraitTwoToTaitOne(dstReg, splitDst);
        }
    } else if constexpr (sizeof(ActualT) < 8) {
        constexpr auto mergeMode = GetMaskMergeMode<mode>();
        vshrs(dstReg, srcReg, scalar, mask, mergeMode);
    }
}

/*
 * Right-shifts each 64-bit lane of a RegTraitNumTwo register by `scalar`. It is built from
 * 32-bit shifts on the two halves reg[0] and reg[1], combined with Or. Only uint64_t and int64_t
 * are handled; for any other ActualT the function body is empty.
 *
 * NOTE(review): the formulas read as if reg[0] holds the low 32 bits of each lane and reg[1] the
 * high 32 bits. Confirm against the RegTraitNumTwo layout.
 *
 * NOTE(review): the counts `B32Width - scalar` and `B32Width + scalar` can be negative or >= 32.
 * Correctness relies on vshls/vshrs yielding 0 (or all sign bits for signed vshrs) for counts
 * >= 32 and treating negative counts as a shift in the opposite direction. Confirm against the
 * instruction documentation.
 *
 * Aliasing: dstReg.reg[0] is written before srcReg.reg[0] is read again, so dstReg must not
 * alias srcReg. ShiftRightsImpl passes a separate temporary as dstReg.
 *
 * NOTE(review): ShiftRightsImpl calls this helper unqualified before it is declared. That
 * presumably resolves through argument-dependent lookup at instantiation; consider moving this
 * definition above its caller.
 */
template <typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void ShiftRightsB64Impl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    constexpr auto modeValue = GetMaskMergeMode<mode>();
    // Bit width of one 32-bit half.
    int16_t B32Width = 32;
    if constexpr (std::is_same_v<ActualT, uint64_t>) {
        RegTensor<uint32_t> tmpReg0;
        RegTensor<uint32_t> tmpReg1;
        // dst.reg[0] = (src.reg[0] >> scalar) | (src.reg[1] << (32 - scalar))
        // The second term carries the bits shifted out of reg[1] down into reg[0].
        vshrs(tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], scalar, mask, modeValue);
        vshls(tmpReg1, (RegTensor<uint32_t> &)srcReg.reg[1], B32Width - scalar, mask, modeValue);
        Or((RegTensor<uint32_t> &)dstReg.reg[0], tmpReg0, tmpReg1, mask);
        // dst.reg[1] = (src.reg[0] >> (32 + scalar)) | (src.reg[1] >> scalar)
        vshrs(tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], B32Width + scalar, mask, modeValue);
        vshrs(tmpReg1, (RegTensor<uint32_t> &)srcReg.reg[1], scalar, mask, modeValue);
        Or((RegTensor<uint32_t> &)dstReg.reg[1], tmpReg0, tmpReg1, mask);
    } else if constexpr (std::is_same_v<ActualT, int64_t>) {
        // Same scheme as uint64_t. Shifts of src.reg[0] run on uint32_t views, while shifts of
        // src.reg[1] run on int32_t.
        // NOTE(review): presumably so that the int32_t vshrs on src.reg[1] propagates the sign
        // bit (arithmetic shift). Confirm.
        RegTensor<int32_t> tmpReg0;
        RegTensor<int32_t> tmpReg1;
        vshrs((RegTensor<uint32_t> &)tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], scalar, mask, modeValue);
        vshls((RegTensor<int32_t> &)tmpReg1, (RegTensor<int32_t> &)srcReg.reg[1], B32Width - scalar, mask, modeValue);
        Or((RegTensor<int32_t> &)dstReg.reg[0], tmpReg0, tmpReg1, mask);
        vshrs((RegTensor<uint32_t> &)tmpReg0, (RegTensor<uint32_t> &)srcReg.reg[0], B32Width + scalar, mask, modeValue);
        vshrs(tmpReg1, (RegTensor<int32_t> &)srcReg.reg[1], scalar, mask, modeValue);
        Or((RegTensor<int32_t> &)dstReg.reg[1], tmpReg0, tmpReg1, mask);
    }
}

/*
 * Applies the vlrelu instruction to srcReg with `scalar` on the lanes selected by mask, writing
 * the result to dstReg.
 *
 * Only half and float registers and scalars are accepted, and only MaskMergeMode::ZEROING.
 * NOTE(review): `scalar` is presumably the Leaky ReLU negative-side slope; confirm against the
 * vlrelu documentation.
 */
template <typename T = DefaultType, typename ScalarT, MaskMergeMode mode = MaskMergeMode::ZEROING, typename RegT>
__simd_callee__ inline void LeakyReluImpl(RegT &dstReg, RegT &srcReg, ScalarT scalar, MaskReg &mask)
{
    using ActualT = typename RegT::ActualT;
    static_assert(std::is_same_v<T, DefaultType> || std::is_same_v<T, ActualT>, "T type is not correct!");
    static_assert(SupportType<ActualT, half, float>(), "current data type is not supported on current device!");
    static_assert(SupportType<ScalarT, half, float>(), "current scalar data type is not supported on current device!");
    static_assert(Std::is_convertible<ScalarT, ActualT>(), "scalar data type could be converted to RegTensor data type");
    static_assert(SupportEnum<mode, MaskMergeMode::ZEROING>(),
        "current LeakyRelu api only supported Mode ZEROING on current device!");

    constexpr auto mergeMode = GetMaskMergeMode<mode>();
    vlrelu(dstReg, srcReg, scalar, mask, mergeMode);
}
} // namespace MicroAPI
} // namespace AscendC
#endif // ASCENDC_MODULE_MICRO_VEC_BINARY_SCALAR_IMPL_H
